Note
Go to the end to download the full example code
IBL’s Recordings from the striatum of a mouse performing a visual discrimination task¶
In this notebook we download publicly available data from the International Brain Laboratory, epoch it, run svGPFA and plot its results.
1. Setup environment¶
Import required packages¶
import sys
import warnings
import pickle
import time
import configparser
import numpy as np
import pandas as pd
import torch
from one.api import ONE
import brainbox.io.one
import iblUtils
import gcnu_common.utils.neural_data_analysis
import gcnu_common.stats.pointProcesses.tests
import gcnu_common.utils.config_dict
import svGPFA.stats.svGPFAModelFactory
import svGPFA.stats.svEM
import svGPFA.utils.miscUtils
import svGPFA.utils.initUtils
import svGPFA.plot.plotUtilsPlotly
Set data parameters¶
# -- Data parameters --
# IBL experiment (session) ID to download from the public Open Alyx database
eID = "ebe2efe3-e8a1-451a-8947-76ef42427cc9"
# probe whose recordings are analyzed
probe_id = "probe00"
# trial event to which spikes times are aligned (epoched)
epoch_event_name = "response_times"
# CSV files listing the cluster IDs and trial IDs to keep
clusters_ids_filename = "../init/clustersIDs_40_64.csv"
trials_ids_filename = "../init/trialsIDs_0_89.csv"
# units with a trial-averaged firing rate below this threshold are removed
# (presumably spikes/sec — confirm against removeUnitsWithLessTrialAveragedFiringRateThanThr)
min_neuron_trials_avg_firing_rate = 0.1
Set estimation hyperparameters¶
# -- Estimation hyperparameters --
n_latents = 10           # number of latent processes
em_max_iter_dyn = 200    # maximum number of EM iterations
common_n_ind_points = 15 # number of inducing points (common per the parameter name)
est_init_number = 0      # selects which initialization INI file to read below
est_init_config_filename_pattern = "../init/{:08d}_IBL_estimation_metaData.ini"
# NOTE(review): "stiatum" looks like a typo for "striatum", but the
# pre-computed results are loaded below under this exact filename —
# confirm the file on disk before renaming
model_save_filename = "../results/stiatum_ibl_model.pickle"
2. Epoch¶
Download data¶
# Connect to the public IBL Open Alyx database and download the spikes,
# clusters and trials objects for the selected experiment ID.
one = ONE(base_url='https://openalyx.internationalbrainlab.org',
          password='international', silent=True)
# Use probe_id for both collections: the original hard-coded
# "alf/probe00/pykilosort" for spikes, which would silently disagree with
# the clusters collection if probe_id were changed above.
spikes = one.load_object(eID, "spikes", f"alf/{probe_id}/pykilosort")
clusters = one.load_object(eID, "clusters", f"alf/{probe_id}/pykilosort")
trials = one.load_object(eID, 'trials')
Extract variables of interest¶
# IDs of all clusters that fired at least one spike in this recording
clusters_ids = np.unique(spikes.clusters.tolist())
n_clusters = len(clusters_ids)
# channel on which each cluster was detected
channels_for_clusters_ids = clusters.channels
# brain-region acronym of each cluster's channel, from the electrode locations
els = brainbox.io.one.load_channel_locations(eID, one=one)
locs_for_clusters_ids = els[probe_id]["acronym"][channels_for_clusters_ids].tolist()
Epoch spikes times¶
# Times of the alignment event, one per trial
epoch_times = trials[epoch_event_name]
n_trials = len(epoch_times)
# Trial boundaries from the trials table: intervals[r] = (start, end)
epoch_start_times = [trials["intervals"][r][0] for r in range(n_trials)]
epoch_end_times = [trials["intervals"][r][1] for r in range(n_trials)]
spikes_times_by_neuron = []
# Epoch each cluster's spike train into per-trial segments
# (presumably re-referenced to the epoch event time, since the trial
# start/end times below are made relative to epoch_times — confirm in
# iblUtils.epoch_neuron_spikes_times)
for cluster_id in clusters_ids:
    print(f"Processing cluster {cluster_id}")
    neuron_spikes_times = spikes.times[spikes.clusters==cluster_id]
    n_epoched_spikes_times = iblUtils.epoch_neuron_spikes_times(
        neuron_spikes_times=neuron_spikes_times,
        epoch_times = epoch_times,
        epoch_start_times=epoch_start_times,
        epoch_end_times=epoch_end_times)
    spikes_times_by_neuron.append(n_epoched_spikes_times)
# Reorganize from [neuron][trial] into [trial][neuron], the layout used below
spikes_times = [[spikes_times_by_neuron[n][r] for n in range(n_clusters)]
                for r in range(n_trials)]
# Trial start/end times expressed relative to the alignment event
trials_start_times = [epoch_start_times[r]-epoch_times[r] for r in range(n_trials)]
trials_end_times = [epoch_end_times[r]-epoch_times[r] for r in range(n_trials)]
n_neurons = len(spikes_times[0])
Processing cluster 0
Processing cluster 1
Processing cluster 2
Processing cluster 3
Processing cluster 4
Processing cluster 5
Processing cluster 6
Processing cluster 7
Processing cluster 8
Processing cluster 9
Processing cluster 10
Processing cluster 11
Processing cluster 12
Processing cluster 13
Processing cluster 14
Processing cluster 15
Processing cluster 16
Processing cluster 17
Processing cluster 18
Processing cluster 19
Processing cluster 20
Processing cluster 21
Processing cluster 22
Processing cluster 23
Processing cluster 24
Processing cluster 25
Processing cluster 26
Processing cluster 27
Processing cluster 28
Processing cluster 29
Processing cluster 30
Processing cluster 31
Processing cluster 32
Processing cluster 33
Processing cluster 34
Processing cluster 35
Processing cluster 36
Processing cluster 37
Processing cluster 38
Processing cluster 39
Processing cluster 40
Processing cluster 41
Processing cluster 42
Processing cluster 43
Processing cluster 44
Processing cluster 45
Processing cluster 46
Processing cluster 47
Processing cluster 48
Processing cluster 49
Processing cluster 50
Processing cluster 51
Processing cluster 52
Processing cluster 53
Processing cluster 54
Processing cluster 55
Processing cluster 56
Processing cluster 57
Processing cluster 58
Processing cluster 59
Processing cluster 60
Processing cluster 61
Processing cluster 62
Processing cluster 63
Processing cluster 64
Processing cluster 65
Processing cluster 66
Processing cluster 67
Processing cluster 68
Processing cluster 69
Processing cluster 70
Processing cluster 71
Processing cluster 72
Processing cluster 73
Processing cluster 74
Processing cluster 75
Processing cluster 76
Processing cluster 77
Processing cluster 78
Processing cluster 79
Processing cluster 80
Processing cluster 81
Processing cluster 82
Processing cluster 83
Processing cluster 84
Processing cluster 85
Processing cluster 86
Processing cluster 87
Processing cluster 88
Processing cluster 89
Processing cluster 90
Processing cluster 91
Processing cluster 92
Processing cluster 93
Processing cluster 94
Processing cluster 95
Processing cluster 96
Processing cluster 97
Processing cluster 98
Processing cluster 99
Processing cluster 100
Processing cluster 101
Processing cluster 102
Processing cluster 103
Processing cluster 104
Processing cluster 105
Processing cluster 106
Processing cluster 107
Processing cluster 108
Processing cluster 109
Processing cluster 110
Processing cluster 111
Processing cluster 112
Processing cluster 113
Processing cluster 114
Processing cluster 115
Processing cluster 116
Processing cluster 117
Processing cluster 118
Processing cluster 119
Processing cluster 120
Processing cluster 121
Processing cluster 122
Processing cluster 123
Processing cluster 124
Processing cluster 125
Processing cluster 126
Processing cluster 127
Processing cluster 128
Processing cluster 129
Processing cluster 130
Processing cluster 131
Processing cluster 132
Processing cluster 133
Processing cluster 134
Processing cluster 135
Processing cluster 136
Processing cluster 137
Processing cluster 138
Processing cluster 139
Processing cluster 140
Processing cluster 141
Processing cluster 142
Processing cluster 143
Processing cluster 144
Processing cluster 145
Processing cluster 146
Processing cluster 147
Processing cluster 148
Processing cluster 149
Processing cluster 150
Processing cluster 151
Processing cluster 152
Processing cluster 153
Processing cluster 154
Processing cluster 155
Processing cluster 156
Processing cluster 157
Processing cluster 158
Processing cluster 159
Processing cluster 160
Processing cluster 161
Processing cluster 162
Processing cluster 163
Processing cluster 164
Processing cluster 165
Processing cluster 166
Processing cluster 167
Processing cluster 168
Processing cluster 169
Processing cluster 170
Processing cluster 171
Processing cluster 172
Processing cluster 173
Processing cluster 174
Processing cluster 175
Processing cluster 176
Processing cluster 177
Processing cluster 178
Processing cluster 179
Processing cluster 180
Processing cluster 181
Processing cluster 182
Processing cluster 183
Processing cluster 184
Processing cluster 185
Processing cluster 186
Processing cluster 187
Processing cluster 188
Processing cluster 189
Processing cluster 190
Processing cluster 191
Processing cluster 192
Processing cluster 193
Processing cluster 194
Processing cluster 195
Processing cluster 196
Processing cluster 197
Processing cluster 198
Processing cluster 199
Processing cluster 200
Processing cluster 201
Processing cluster 202
Processing cluster 203
Processing cluster 204
Processing cluster 205
Processing cluster 206
Processing cluster 207
Processing cluster 208
Processing cluster 209
Processing cluster 210
Processing cluster 211
Processing cluster 212
Processing cluster 213
Processing cluster 214
Processing cluster 215
Processing cluster 216
Processing cluster 217
Processing cluster 218
Processing cluster 219
Processing cluster 220
Processing cluster 221
Processing cluster 222
Processing cluster 223
Processing cluster 224
Processing cluster 225
Processing cluster 226
Processing cluster 227
Processing cluster 228
Processing cluster 229
Processing cluster 230
Processing cluster 231
Processing cluster 232
Processing cluster 233
Processing cluster 234
Processing cluster 235
Processing cluster 236
Processing cluster 237
Processing cluster 238
Processing cluster 239
Processing cluster 240
Processing cluster 241
Processing cluster 242
Processing cluster 243
Processing cluster 244
Processing cluster 245
Processing cluster 246
Processing cluster 247
Processing cluster 248
Processing cluster 249
Processing cluster 250
Processing cluster 251
Processing cluster 252
Processing cluster 253
Processing cluster 254
Processing cluster 255
Processing cluster 256
Processing cluster 257
Processing cluster 258
Processing cluster 259
Processing cluster 260
Processing cluster 261
Processing cluster 262
Processing cluster 263
Processing cluster 264
Processing cluster 265
Processing cluster 266
Processing cluster 267
Processing cluster 268
Processing cluster 269
Processing cluster 270
Processing cluster 271
Processing cluster 272
Processing cluster 273
Processing cluster 274
Processing cluster 275
Processing cluster 276
Processing cluster 277
Processing cluster 278
Processing cluster 279
Processing cluster 280
Processing cluster 281
Processing cluster 282
Processing cluster 283
Processing cluster 284
Processing cluster 285
Processing cluster 286
Subset epoched spikes times¶
# Keep only the clusters listed in the init CSV file
selected_clusters_ids = np.genfromtxt(clusters_ids_filename, dtype=np.uint64)
spikes_times = iblUtils.subset_clusters_ids_data(
    selected_clusters_ids=selected_clusters_ids,
    clusters_ids=clusters_ids,
    spikes_times=spikes_times,
)
n_neurons = len(spikes_times[0])
n_trials = len(spikes_times)
trials_ids = np.arange(n_trials)
# Keep only the trials listed in the init CSV file
selected_trials_ids = np.genfromtxt(trials_ids_filename, dtype=np.uint64)
spikes_times, trials_start_times, trials_end_times = \
    iblUtils.subset_trials_ids_data(
        selected_trials_ids=selected_trials_ids,
        trials_ids=trials_ids,
        spikes_times=spikes_times,
        trials_start_times=trials_start_times,
        trials_end_times=trials_end_times)
n_trials = len(spikes_times)
# Remove units whose trial-averaged firing rate falls below
# min_neuron_trials_avg_firing_rate
neurons_indices = torch.arange(n_neurons)
trials_durations = [trials_end_times[i] - trials_start_times[i]
                    for i in range(n_trials)]
spikes_times, neurons_indices = \
    gcnu_common.utils.neural_data_analysis.removeUnitsWithLessTrialAveragedFiringRateThanThr(
        spikes_times=spikes_times, neurons_indices=neurons_indices,
        trials_durations=trials_durations,
        min_neuron_trials_avg_firing_rate=min_neuron_trials_avg_firing_rate)
# Keep the cluster IDs aligned with the units that survived the filtering
selected_clusters_ids = [selected_clusters_ids[i] for i in neurons_indices]
n_trials = len(spikes_times)
n_neurons = len(spikes_times[0])
Check that spikes have been epoched correctly¶
Plot spikes¶
Plot the spikes of all trials of a randomly chosen neuron. Most trials should contain at least one spike.
# Pick one neuron uniformly at random and plot its spikes in every trial
neuron_to_plot_index = torch.randint(low=0, high=n_neurons, size=(1,)).item()
fig = svGPFA.plot.plotUtilsPlotly.getSpikesTimesPlotOneNeuron(
    spikes_times=spikes_times,
    neuron_index=neuron_to_plot_index,
    title=f"Neuron index: {neuron_to_plot_index}",
)
fig
Run some simple checks on spikes¶
The function checkEpochedSpikesTimes tests that:
every neuron fired at least one spike across all trials,
for each trial, the spikes times of every neuron are between the trial start and end times.
If any check fails, a ValueError will be raised. Otherwise a checks
passed message should be printed.
# Sanity-check the epoched spikes times. checkEpochedSpikesTimes raises a
# ValueError if any check fails, so reaching the print means all checks
# passed. (The original wrapped this call in an ``except ValueError:
# raise`` handler, which is a no-op and has been removed.)
gcnu_common.utils.neural_data_analysis.checkEpochedSpikesTimes(
    spikes_times=spikes_times, trials_start_times=trials_start_times,
    trials_end_times=trials_end_times,
)
print("Checks passed")
Checks passed
3. Get parameters¶
Dynamic parameters specification¶
# Parameters provided programmatically (as opposed to the configuration
# file read below).
dynamic_params_spec = dict(
    optim_params=dict(em_max_iter=em_max_iter_dyn),
    ind_points_locs_params0=dict(common_n_ind_points=common_n_ind_points),
)
Config file parameters specification¶
The configuration file appears here
# Build the config-file parameter specification: read the INI file selected
# by est_init_number and convert its string values into typed parameters.
args_info = svGPFA.utils.initUtils.getArgsInfo()
est_init_config_filename = est_init_config_filename_pattern.format(
    est_init_number)
est_init_config = configparser.ConfigParser()
est_init_config.read(est_init_config_filename)
strings_dict = gcnu_common.utils.config_dict.GetDict(
    config=est_init_config).get_dict()
config_file_params_spec = \
    svGPFA.utils.initUtils.getParamsDictFromStringsDict(
        n_latents=n_latents, n_trials=n_trials,
        strings_dict=strings_dict, args_info=args_info)
Get the parameters from the dynamic and configuration file parameter specifications¶
# Merge the dynamic and config-file specifications into the final model
# parameters and kernel types (the extraction log below suggests dynamic
# values take precedence over config-file ones — confirm in initUtils)
params, kernels_types = svGPFA.utils.initUtils.getParamsAndKernelsTypes(
    n_trials=n_trials, n_neurons=n_neurons, n_latents=n_latents,
    trials_start_times=trials_start_times,
    trials_end_times=trials_end_times,
    dynamic_params_spec=dynamic_params_spec,
    config_file_params_spec=config_file_params_spec)
Extracted config_file_params_spec[optim_params][n_quad]=200
Extracted dynamic_params_spec[ind_points_locs_params0][common_n_ind_points]=15
Extracted from config_file c0_distribution=Normal, c0_loc=0.0, c0_scale=1.0, c0_random_seed=None
Extracted from config_file d0_distribution=Normal, d0_loc=0.0, d0_scale=1.0, d0_random_seed=None
Extracted from config_file k_type=exponentialQuadratic and k_lengthsales0=0.3
Extracted from config_file ind_points_locs0_layout=equidistant
Extracted from config_file variational_mean0_constant_value=0.0
Extracted from config_file variational_cov0_diag_value=0.01
Extracted config_file_params_spec[optim_params][n_quad]=200
Extracted config_file_params_spec[optim_params][prior_cov_reg_param]=0.001
Extracted config_file_params_spec[optim_params][optim_method]=ECM
Extracted dynamic_params_spec[optim_params][em_max_iter]=200
Extracted config_file_params_spec[optim_params][verbose]=True
Extracted config_file_params_spec[optim_params][estep_estimate]=True
Extracted config_file_params_spec[optim_params][estep_max_iter]=20
Extracted config_file_params_spec[optim_params][estep_lr]=1.0
Extracted config_file_params_spec[optim_params][estep_tolerance_grad]=0.001
Extracted config_file_params_spec[optim_params][estep_tolerance_change]=1e-05
Extracted config_file_params_spec[optim_params][estep_line_search_fn]=strong_wolfe
Extracted config_file_params_spec[optim_params][mstep_embedding_estimate]=True
Extracted config_file_params_spec[optim_params][mstep_embedding_max_iter]=20
Extracted config_file_params_spec[optim_params][mstep_embedding_lr]=1.0
Extracted config_file_params_spec[optim_params][mstep_embedding_tolerance_grad]=0.001
Extracted config_file_params_spec[optim_params][mstep_embedding_tolerance_change]=1e-05
Extracted config_file_params_spec[optim_params][mstep_embedding_line_search_fn]=strong_wolfe
Extracted config_file_params_spec[optim_params][mstep_kernels_estimate]=True
Extracted config_file_params_spec[optim_params][mstep_kernels_max_iter]=20
Extracted config_file_params_spec[optim_params][mstep_kernels_lr]=1.0
Extracted config_file_params_spec[optim_params][mstep_kernels_tolerance_grad]=0.001
Extracted config_file_params_spec[optim_params][mstep_kernels_tolerance_change]=1e-05
Extracted config_file_params_spec[optim_params][mstep_kernels_line_search_fn]=strong_wolfe
Extracted config_file_params_spec[optim_params][mstep_indpointslocs_estimate]=True
Extracted config_file_params_spec[optim_params][mstep_indpointslocs_max_iter]=20
Extracted config_file_params_spec[optim_params][mstep_indpointslocs_lr]=1.0
Extracted config_file_params_spec[optim_params][mstep_indpointslocs_tolerance_grad]=0.001
Extracted config_file_params_spec[optim_params][mstep_indpointslocs_tolerance_change]=1e-05
Extracted config_file_params_spec[optim_params][mstep_indpointslocs_line_search_fn]=strong_wolfe
4. Estimate svGPFA model¶
Create kernels, a model and set its initial parameters¶
Build kernels¶
# Build the GP kernel objects from the initial kernel parameters
kernels_params0 = params["initial_params"]["posterior_on_latents"]["kernels_matrices_store"]["kernels_params0"]
kernels = svGPFA.utils.miscUtils.buildKernels(
    kernels_types=kernels_types, kernels_params=kernels_params0)
Create model¶
# Build a point-process svGPFA model with an exponential link and a linear
# embedding, using Cholesky-based representations for kernel-matrix
# inverses and inducing-points covariances
kernelMatrixInvMethod = svGPFA.stats.svGPFAModelFactory.kernelMatrixInvChol
indPointsCovRep = svGPFA.stats.svGPFAModelFactory.indPointsCovChol
model = svGPFA.stats.svGPFAModelFactory.SVGPFAModelFactory.buildModelPyTorch(
    conditionalDist=svGPFA.stats.svGPFAModelFactory.PointProcess,
    linkFunction=svGPFA.stats.svGPFAModelFactory.ExponentialLink,
    embeddingType=svGPFA.stats.svGPFAModelFactory.LinearEmbedding,
    kernels=kernels, kernelMatrixInvMethod=kernelMatrixInvMethod,
    indPointsCovRep=indPointsCovRep)
Set initial parameters¶
# Attach the spikes and the initial, quadrature and regularization
# parameters to the model
model.setParamsAndData(
    measurements=spikes_times,
    initial_params=params["initial_params"],
    eLLCalculationParams=params["ell_calculation_params"],
    priorCovRegParam=params["optim_params"]["prior_cov_reg_param"])
Maximize the Lower Bound¶
(Warning: with the parameters above, this step takes around 5 minutes for 30 em_max_iter)
# The estimation below is commented out because it is slow (see the warning
# above); instead, pre-computed results are loaded from model_save_filename.
# Uncomment to re-run the optimization and overwrite the saved results.
# svEM = svGPFA.stats.svEM.SVEM_PyTorch()
# tic = time.perf_counter()
# lowerBoundHist, elapsedTimeHist, terminationInfo, iterationsModelParams = \
#     svEM.maximize(model=model, optim_params=params["optim_params"],
#                   method=params["optim_params"]["optim_method"], out=sys.stdout)
# toc = time.perf_counter()
# print(f"Elapsed time {toc - tic:0.4f} seconds")
# resultsToSave = {"lowerBoundHist": lowerBoundHist,
#                  "elapsedTimeHist": elapsedTimeHist,
#                  "terminationInfo": terminationInfo,
#                  "iterationModelParams": iterationsModelParams,
#                  "model": model}
# with open(model_save_filename, "wb") as f:
#     pickle.dump(resultsToSave, f)
# print("Saved results to {:s}".format(model_save_filename))
# NOTE: pickle.load executes arbitrary code from the file; acceptable here
# because the pickle is produced by this very notebook, not untrusted input
with open(model_save_filename, "rb") as f:
    estResults = pickle.load(f)
lowerBoundHist = estResults["lowerBoundHist"]
elapsedTimeHist = estResults["elapsedTimeHist"]
model = estResults["model"]
5. Goodness-of-fit analysis¶
Set goodness-of-fit variables¶
# -- Goodness-of-fit variables --
ksTestGamma = 10         # gamma parameter of the KS time-rescaling test below
trial_for_gof = 0        # trial used for goodness-of-fit
cluster_id_for_gof = 41  # cluster ID used for goodness-of-fit
n_time_steps_IF = 100    # time steps at which to evaluate intensity functions
# position of cluster_id_for_gof within the selected-clusters subset
cluster_id_for_gof_index = torch.nonzero(torch.IntTensor(selected_clusters_ids)==cluster_id_for_gof)
# per-trial evaluation times spanning each trial's start to end
trials_times = svGPFA.utils.miscUtils.getTrialsTimes(
    start_times=trials_start_times,
    end_times=trials_end_times,
    n_steps=n_time_steps_IF)
Calculate expected intensity function values (for KS test and IF plots)¶
# Expected posterior conditional intensity functions (CIFs) at the trial
# times; no_grad avoids building an autograd graph that is not needed here
with torch.no_grad():
    cif_values = model.computeExpectedPosteriorCIFs(times=trials_times)
cif_values_GOF = cif_values[trial_for_gof][cluster_id_for_gof_index]
Perform a time-rescaling KS test (with numerical correction)¶
# Times and spikes of the trial/neuron chosen for goodness-of-fit
# NOTE(review): this indexing implies trials_times has shape
# (trial, time_step, 1) — confirm against getTrialsTimes
trial_times_GOF = trials_times[trial_for_gof, :, 0]
spikes_times_GOF = spikes_times[trial_for_gof][cluster_id_for_gof_index]
if len(spikes_times_GOF) == 0:
    raise ValueError("No spikes found for goodness-of-fit analysis")
title = "Trial {:d}, Neuron {:d} ({:d} spikes)".format(
    trial_for_gof, cluster_id_for_gof, len(spikes_times_GOF))
# Time-rescaling KS test with numerical correction; warnings raised by the
# numerical routine are suppressed to keep the notebook output clean
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    diffECDFsX, diffECDFsY, estECDFx, estECDFy, simECDFx, simECDFy, cb = \
        gcnu_common.stats.pointProcesses.tests.\
        KSTestTimeRescalingNumericalCorrection(spikes_times=spikes_times_GOF,
            cif_times=trial_times_GOF, cif_values=cif_values_GOF,
            gamma=ksTestGamma)
Processing given ISIs
Processing iter 0/9
Processing iter 1/9
Processing iter 2/9
Processing iter 3/9
Processing iter 4/9
Processing iter 5/9
Processing iter 6/9
Processing iter 7/9
Processing iter 8/9
Processing iter 9/9
Plot result of time-rescaling KS-test¶
# Plot the ECDFs and confidence band returned by the time-rescaling KS test
fig = svGPFA.plot.plotUtilsPlotly.getPlotResKSTestTimeRescalingNumericalCorrection(
    diffECDFsX=diffECDFsX, diffECDFsY=diffECDFsY,
    estECDFx=estECDFx, estECDFy=estECDFy,
    simECDFx=simECDFx, simECDFy=simECDFy,
    cb=cb, title=title,
)
fig
Perform ROC predictive analysis¶
# ROC predictive analysis: classify spike occurrence from the expected
# intensity values (returns false-positive rates, true-positive rates, AUC)
fpr, tpr, roc_auc = svGPFA.utils.miscUtils.computeSpikeClassificationROC(
    spikes_times=spikes_times_GOF,
    cif_times=trial_times_GOF,
    cif_values=cif_values_GOF)
/nfs/ghome/live/rapela/dev/work/ucl/gatsby-swc/gatsby/svGPFA/repos/svGPFA/src/svGPFA/utils/miscUtils.py:262: UserWarning:
Found more than one spike in 1 bins
Plot result of ROC predictive analysis¶
fig = svGPFA.plot.plotUtilsPlotly.getPlotResROCAnalysis(
fpr=fpr, tpr=tpr, auc=roc_auc, title=title)
fig
6. Plotting¶
Imports for plotting¶
import numpy as np
import pandas as pd
import svGPFA.plot.plotUtilsPlotly
Set plotting variables¶
# -- Plotting variables --
latent_to_plot = 0             # latent index for the 2D latent plots
latents_to_3D_plot = [0, 2, 4] # latent indices for the 3D scatter plot
cluster_id_to_plot = 41
trial_to_plot = 0
# per-choice color patterns: blue when choice == -1, red otherwise
choices_colors_patterns = ["rgba(0,0,255,{:f})", "rgba(255,0,0,{:f})"]
align_event_name = "response_times"
events_names = ["stimOn_times", "response_times", "stimOff_times"]
events_colors = ["magenta", "green", "black"]
events_markers = ["circle", "circle", "circle"]
# position of cluster_id_to_plot within the selected-clusters subset
cluster_id_to_plot_index = torch.nonzero(torch.IntTensor(selected_clusters_ids)==cluster_id_to_plot)
n_trials = len(spikes_times)
# per-trial behavioral covariates, restricted to the selected trials
trials_choices = [trials["choice"][trial_id] for trial_id in selected_trials_ids]
trials_rewarded = [trials["feedbackType"][trial_id] for trial_id in selected_trials_ids]
trials_contrast = [trials["contrastRight"][trial_id]
                   if not np.isnan(trials["contrastRight"][trial_id])
                   else trials["contrastLeft"][trial_id]
                   for trial_id in selected_trials_ids]
trials_colors_patterns = [choices_colors_patterns[0]
                          if trials_choices[r] == -1
                          else choices_colors_patterns[1]
                          for r in range(n_trials)]
trials_colors = [trial_color_pattern.format(1.0)
                 for trial_color_pattern in trials_colors_patterns]
# Previous-trial covariates: np.nan marks the missing value of the first
# trial. (The original used np.NAN, an alias removed in NumPy 2.0.)
trials_annotations = {"choice": trials_choices,
                      "rewarded": trials_rewarded,
                      "contrast": trials_contrast,
                      "choice_prev": np.insert(trials_choices[:-1], 0, np.nan),
                      "rewarded_prev": np.insert(trials_rewarded[:-1], 0,
                                                 np.nan)}
# times, colors and markers of the behavioral events to overlay on plots
events_times = []
for event_name in events_names:
    events_times.append([trials[event_name][trial_id]
                         for trial_id in selected_trials_ids])
marked_events_times, marked_events_colors, marked_events_markers = \
    iblUtils.buildMarkedEventsInfo(events_times=events_times,
                                   events_colors=events_colors,
                                   events_markers=events_markers)
align_event_times = [trials[align_event_name][trial_id]
                     for trial_id in selected_trials_ids]
Plot lower bound history¶
fig = svGPFA.plot.plotUtilsPlotly.getPlotLowerBoundHist(
elapsedTimeHist=elapsedTimeHist, lowerBoundHist=lowerBoundHist)
fig
Plot estimated latent across trials¶
testMuK, testVarK = model.predictLatents(times=trials_times)
fig = svGPFA.plot.plotUtilsPlotly.getPlotLatentAcrossTrials(
times=trials_times.numpy(),
latentsMeans=testMuK,
latentsSTDs=torch.sqrt(testVarK),
trials_ids=selected_trials_ids,
latentToPlot=latent_to_plot,
trials_colors_patterns=trials_colors_patterns,
xlabel="Time (msec)")
fig
Plot orthonormalized estimated latent across trials¶
testMuK, _ = model.predictLatents(times=trials_times)
test_mu_k_np = [testMuK[r].detach().numpy() for r in range(len(testMuK))]
estimatedC, estimatedD = model.getSVEmbeddingParams()
estimatedC_np = estimatedC.detach().numpy()
fig = svGPFA.plot.plotUtilsPlotly.getPlotOrthonormalizedLatentAcrossTrials(
trials_times=trials_times, latentsMeans=test_mu_k_np, latentToPlot=latent_to_plot,
align_event_times=align_event_times,
marked_events_times=marked_events_times,
marked_events_colors=marked_events_colors,
marked_events_markers=marked_events_markers,
trials_colors=trials_colors,
trials_annotations=trials_annotations,
C=estimatedC_np, trials_ids=selected_trials_ids,
xlabel="Time (msec)")
fig
Plot 3D scatter plot of orthonormalized latents¶
fig = svGPFA.plot.plotUtilsPlotly.get3DPlotOrthonormalizedLatentsAcrossTrials(
trials_times=trials_times.numpy(), latentsMeans=test_mu_k_np,
C=estimatedC_np, trials_ids=selected_trials_ids,
latentsToPlot=latents_to_3D_plot,
align_event_times=align_event_times,
marked_events_times=marked_events_times,
marked_events_colors=marked_events_colors,
marked_events_markers=marked_events_markers,
trials_colors=trials_colors,
trials_annotations=trials_annotations)
fig
Plot embedding¶
embeddingMeans, embeddingVars = model.predictEmbedding(times=trials_times)
embeddingMeans = embeddingMeans.detach().numpy()
embeddingVars = embeddingVars.detach().numpy()
title = "Neuron {:d}".format(cluster_id_to_plot)
fig = svGPFA.plot.plotUtilsPlotly.getPlotEmbeddingAcrossTrials(
times=trials_times.numpy(),
embeddingsMeans=embeddingMeans[:, :, cluster_id_to_plot_index],
embeddingsSTDs=np.sqrt(embeddingVars[:, :, cluster_id_to_plot_index]),
trials_colors_patterns=trials_colors_patterns,
title=title)
fig
Plot intensity functions for one neuron and all trials¶
title = f"Cluster ID: {clusters_ids[cluster_id_to_plot_index]}, Region: {locs_for_clusters_ids[cluster_id_to_plot]}"
fig = svGPFA.plot.plotUtilsPlotly.getPlotCIFsOneNeuronAllTrials(
trials_times=trials_times,
cif_values=cif_values,
neuron_index=cluster_id_to_plot_index,
spikes_times=spikes_times,
trials_ids=selected_trials_ids,
align_event_times=align_event_times,
marked_events_times=marked_events_times,
marked_events_colors=marked_events_colors,
marked_events_markers=marked_events_markers,
trials_annotations=trials_annotations,
trials_colors=trials_colors,
title=title)
fig
Plot orthonormalized embedding parameters¶
hovertemplate = "value: %{y}<br>" + \
"neuron index: %{x}<br>" + \
"%{text}"
text = [f"cluster_id: {cluster_id}" for cluster_id in selected_clusters_ids]
estimatedC, estimatedD = model.getSVEmbeddingParams()
fig = svGPFA.plot.plotUtilsPlotly.getPlotOrthonormalizedEmbeddingParams(
C=estimatedC.numpy(), d=estimatedD.numpy(),
hovertemplate=hovertemplate, text=text)
fig
Plot kernel parameters¶
kernelsParams = model.getKernelsParams()
kernelsTypes = [type(kernel).__name__ for kernel in model.getKernels()]
fig = svGPFA.plot.plotUtilsPlotly.getPlotKernelsParams(
kernelsTypes=kernelsTypes, kernelsParams=kernelsParams)
fig
To run the Python script or Jupyter notebook below, please download them to the examples/sphinx_gallery folder of the repository and execute them from there.
# sphinx_gallery_thumbnail_path = '_static/ibl_logo.png'
Total running time of the script: ( 0 minutes 33.704 seconds)